{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DenseNet\n",
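    "ResNet introduced the idea of cross-layer connections, which directly influenced the convolutional architectures that followed. The best known of these is DenseNet, the CVPR 2017 best paper.\n",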
    "\n",
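    "DenseNet differs from ResNet in that ResNet sums features across layers, whereas DenseNet concatenates features across layers along the channel dimension. The two diagrams below illustrate the difference.\n",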
    "\n",
    "![](https://ws4.sinaimg.cn/large/006tNc79ly1fmpvj5vkfhj30uw0anq73.jpg)\n",
    "\n",
    "![](https://ws1.sinaimg.cn/large/006tNc79ly1fmpvj7fxd1j30vb0eyzqf.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
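    "The first figure shows ResNet and the second shows DenseNet. Because features are concatenated along the channel dimension, the outputs of earlier layers are carried into every later layer. This helps gradients propagate, and it lets the network train jointly on low-level and high-level features, which tends to give better results."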
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DenseNet is built mainly from dense blocks. Below we implement a dense block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.113030Z",
     "start_time": "2017-12-22T15:38:30.612922Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torchvision.datasets import CIFAR10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we define a convolution block whose layers are ordered BN -> ReLU -> conv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.121249Z",
     "start_time": "2017-12-22T15:38:31.115369Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def conv_block(in_channel, out_channel):\n",
    "    layer = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channel),\n",
    "        nn.ReLU(True),\n",
    "        nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False)\n",
    "    )\n",
    "    return layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
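    "In a dense block, the number of output channels of each convolution is called `growth_rate`: if the input has `in_channel` channels and the block has n layers, the output has `in_channel + n * growth_rate` channels."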
   ]
  },
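  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick worked check of this formula, writing $k$ for `growth_rate` (the symbols $C_{\\text{in}}$, $C_{\\text{out}}$ and $k$ are notation introduced here for illustration only), a block with 3 input channels, $k = 12$ and $n = 3$ layers gives\n",
    "\n",
    "$$C_{\\text{out}} = C_{\\text{in}} + n \\cdot k = 3 + 3 \\times 12 = 39$$"
   ]
  },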
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.145274Z",
     "start_time": "2017-12-22T15:38:31.123363Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class dense_block(nn.Module):\n",
    "    def __init__(self, in_channel, growth_rate, num_layers):\n",
    "        super(dense_block, self).__init__()\n",
    "        block = []\n",
    "        channel = in_channel\n",
    "        for i in range(num_layers):\n",
    "            block.append(conv_block(channel, growth_rate))\n",
    "            channel += growth_rate\n",
    "            \n",
    "        self.net = nn.Sequential(*block)\n",
    "        \n",
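    "    def forward(self, x):\n",
    "        for layer in self.net:\n",
    "            out = layer(x)  # each layer produces growth_rate new channels\n",
    "            x = torch.cat((out, x), dim=1)  # concatenate the new features with the layer input along channels\n",
    "        return x"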
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's verify that the number of output channels is correct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.213632Z",
     "start_time": "2017-12-22T15:38:31.147196Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input shape: 3 x 96 x 96\n",
      "output shape: 39 x 96 x 96\n"
     ]
    }
   ],
   "source": [
    "test_net = dense_block(3, 12, 3)\n",
    "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
    "print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
    "test_y = test_net(test_x)\n",
    "print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
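    "Besides the dense block, DenseNet has another module called the transition block. Because DenseNet keeps concatenating along the channel dimension, the number of output channels grows as the network gets deeper, and so do the parameter count and the computation. To avoid this, a transition layer is inserted to reduce the number of channels and to halve the height and width of the feature map. The transition layer can use a 1 x 1 convolution to reduce the channels, followed by 2 x 2 average pooling to halve the spatial size."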
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.222120Z",
     "start_time": "2017-12-22T15:38:31.215770Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def transition(in_channel, out_channel):\n",
    "    trans_layer = nn.Sequential(\n",
    "        nn.BatchNorm2d(in_channel),\n",
    "        nn.ReLU(True),\n",
    "        nn.Conv2d(in_channel, out_channel, 1),\n",
    "        nn.AvgPool2d(2, 2)\n",
    "    )\n",
    "    return trans_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's verify that the transition layer behaves as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.234846Z",
     "start_time": "2017-12-22T15:38:31.224078Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input shape: 3 x 96 x 96\n",
      "output shape: 12 x 48 x 48\n"
     ]
    }
   ],
   "source": [
    "test_net = transition(3, 12)\n",
    "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
    "print('input shape: {} x {} x {}'.format(test_x.shape[1], test_x.shape[2], test_x.shape[3]))\n",
    "test_y = test_net(test_x)\n",
    "print('output shape: {} x {} x {}'.format(test_y.shape[1], test_y.shape[2], test_y.shape[3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we define DenseNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.318822Z",
     "start_time": "2017-12-22T15:38:31.236857Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class densenet(nn.Module):\n",
    "    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=(6, 12, 24, 16)):\n",
    "        super(densenet, self).__init__()\n",
    "        self.block1 = nn.Sequential(\n",
    "            nn.Conv2d(in_channel, 64, 7, 2, 3),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(True),\n",
    "            nn.MaxPool2d(3, 2, padding=1)\n",
    "        )\n",
    "        \n",
    "        channels = 64\n",
    "        block = []\n",
    "        for i, layers in enumerate(block_layers):\n",
    "            block.append(dense_block(channels, growth_rate, layers))\n",
    "            channels += layers * growth_rate\n",
    "            if i != len(block_layers) - 1:\n",
    "                block.append(transition(channels, channels // 2)) # the transition layer halves both the spatial size and the channel count\n",
    "                channels = channels // 2\n",
    "        \n",
    "        self.block2 = nn.Sequential(*block)\n",
    "        self.block2.add_module('bn', nn.BatchNorm2d(channels))\n",
    "        self.block2.add_module('relu', nn.ReLU(True))\n",
    "        self.block2.add_module('avg_pool', nn.AvgPool2d(3))\n",
    "        \n",
    "        self.classifier = nn.Linear(channels, num_classes)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.block1(x)\n",
    "        x = self.block2(x)\n",
    "        \n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.classifier(x)\n",
    "        return x"
   ]
  },
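  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bookkeeping in the constructor concrete, here is a rough shape trace for a 96 x 96 input with the default `growth_rate=32` and the default `block_layers` of 6, 12, 24 and 16 layers. It is worked out by hand from the layer definitions above, not taken from a run. The stem (a stride-2 convolution followed by stride-2 max pooling) reduces the side length from 96 to 48 and then to 24, and each of the three transition layers halves it again:\n",
    "\n",
    "$$96 \\to 48 \\to 24 \\to 12 \\to 6 \\to 3$$\n",
    "\n",
    "For the channels, each dense block adds `layers * growth_rate` and each transition layer halves the count:\n",
    "\n",
    "$$64 \\xrightarrow{+6 \\cdot 32} 256 \\xrightarrow{/2} 128 \\xrightarrow{+12 \\cdot 32} 512 \\xrightarrow{/2} 256 \\xrightarrow{+24 \\cdot 32} 1024 \\xrightarrow{/2} 512 \\xrightarrow{+16 \\cdot 32} 1024$$\n",
    "\n",
    "The final `AvgPool2d(3)` therefore reduces the 3 x 3 feature map to 1 x 1, so the classifier receives 1024 features. This pooling size is matched to a 96 x 96 input; other input sizes would presumably need a different pooling size."
   ]
  },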
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:31.654182Z",
     "start_time": "2017-12-22T15:38:31.320788Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output: torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "test_net = densenet(3, 10)\n",
    "test_x = Variable(torch.zeros(1, 3, 96, 96))\n",
    "test_y = test_net(test_x)\n",
    "print('output: {}'.format(test_y.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T15:38:32.894729Z",
     "start_time": "2017-12-22T15:38:31.656356Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from utils import train\n",
    "\n",
    "def data_tf(x):\n",
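    "    x = x.resize((96, 96), 2) # upscale the image to 96 x 96 (2 selects bilinear resampling in PIL)\n",
    "    x = np.array(x, dtype='float32') / 255\n",
    "    x = (x - 0.5) / 0.5 # normalize to [-1, 1]; this technique is covered later\n",
    "    x = x.transpose((2, 0, 1)) # move the channel axis first, which is the layout PyTorch expects\n",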
    "    x = torch.from_numpy(x)\n",
    "    return x\n",
    "     \n",
    "train_set = CIFAR10('./data', train=True, transform=data_tf)\n",
    "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_set = CIFAR10('./data', train=False, transform=data_tf)\n",
    "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
    "\n",
    "net = densenet(3, 10)\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.01)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T16:15:38.168095Z",
     "start_time": "2017-12-22T15:38:32.896735Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 1.374316, Train Acc: 0.507972, Valid Loss: 1.203217, Valid Acc: 0.572884, Time 00:01:44\n",
      "Epoch 1. Train Loss: 0.912924, Train Acc: 0.681506, Valid Loss: 1.555908, Valid Acc: 0.492286, Time 00:01:50\n",
      "Epoch 2. Train Loss: 0.701387, Train Acc: 0.755794, Valid Loss: 0.815147, Valid Acc: 0.718354, Time 00:01:49\n",
      "Epoch 3. Train Loss: 0.575985, Train Acc: 0.800911, Valid Loss: 0.696013, Valid Acc: 0.759494, Time 00:01:50\n",
      "Epoch 4. Train Loss: 0.479812, Train Acc: 0.836957, Valid Loss: 1.013879, Valid Acc: 0.676226, Time 00:01:51\n",
      "Epoch 5. Train Loss: 0.402165, Train Acc: 0.861413, Valid Loss: 0.674512, Valid Acc: 0.778481, Time 00:01:50\n",
      "Epoch 6. Train Loss: 0.334593, Train Acc: 0.888247, Valid Loss: 0.647112, Valid Acc: 0.791634, Time 00:01:50\n",
      "Epoch 7. Train Loss: 0.278181, Train Acc: 0.907149, Valid Loss: 0.773517, Valid Acc: 0.756527, Time 00:01:51\n",
      "Epoch 8. Train Loss: 0.227948, Train Acc: 0.922714, Valid Loss: 0.654399, Valid Acc: 0.800237, Time 00:01:49\n",
      "Epoch 9. Train Loss: 0.181156, Train Acc: 0.940157, Valid Loss: 1.179013, Valid Acc: 0.685225, Time 00:01:50\n",
      "Epoch 10. Train Loss: 0.151305, Train Acc: 0.950208, Valid Loss: 0.630000, Valid Acc: 0.807951, Time 00:01:50\n",
      "Epoch 11. Train Loss: 0.118433, Train Acc: 0.961077, Valid Loss: 1.247253, Valid Acc: 0.703323, Time 00:01:52\n",
      "Epoch 12. Train Loss: 0.094127, Train Acc: 0.969789, Valid Loss: 1.230697, Valid Acc: 0.723101, Time 00:01:51\n",
      "Epoch 13. Train Loss: 0.086181, Train Acc: 0.972047, Valid Loss: 0.904135, Valid Acc: 0.769284, Time 00:01:50\n",
      "Epoch 14. Train Loss: 0.064248, Train Acc: 0.980359, Valid Loss: 1.665002, Valid Acc: 0.624209, Time 00:01:51\n",
      "Epoch 15. Train Loss: 0.054932, Train Acc: 0.982996, Valid Loss: 0.927216, Valid Acc: 0.774723, Time 00:01:51\n",
      "Epoch 16. Train Loss: 0.043503, Train Acc: 0.987272, Valid Loss: 1.574383, Valid Acc: 0.707377, Time 00:01:52\n",
      "Epoch 17. Train Loss: 0.047615, Train Acc: 0.985154, Valid Loss: 0.987781, Valid Acc: 0.770471, Time 00:01:51\n",
      "Epoch 18. Train Loss: 0.039813, Train Acc: 0.988012, Valid Loss: 2.248944, Valid Acc: 0.631824, Time 00:01:50\n",
      "Epoch 19. Train Loss: 0.030183, Train Acc: 0.991168, Valid Loss: 0.887785, Valid Acc: 0.795392, Time 00:01:51\n"
     ]
    }
   ],
   "source": [
    "train(net, train_data, test_data, 20, optimizer, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DenseNet replaces residual connections with feature concatenation, which gives the network denser connectivity."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
